
import os
import re
import sys
import base64
import openai
import subprocess
from pathlib import Path
from threading import Timer
import inference
import random

# Model name handed to the inference helper by the functions in this file.
model = "GPT4o"
# Tunable settings for the generate -> execute -> verify -> rank pipeline below.
PARAMS = {
    "temperature": 0.5,      # Sampling temperature
    "num_samples": 8,       # Number of valid diagram code solutions we want
    "max_attempts": 80,      # (optional) total upper bound on attempts if desired; NOTE(review): not read anywhere in the visible code
    "execution_timeout": 20, # seconds to allow the code to run
    "mem_capacity": 5,       # max number of previously verified code snippets included in a generation prompt
    "outer_loops": 16,        # N: number of outer loops
    "inner_loops": 5        # M: number of inner loops (thus total attempts = N*M = 16*5 = 80)
}

def read_file(path):
    """Return the entire contents of the text file at ``path``, decoded as UTF-8."""
    return Path(path).read_text(encoding='utf-8')

def write_file(path, content):
    """Write ``content`` to ``path`` as UTF-8 text, creating parent directories.

    Args:
        path: Destination file path. May be a bare filename with no
            directory component, in which case the file is written relative
            to the current working directory.
        content: Text to write; any existing file is overwritten.
    """
    parent_dir = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so only create the parent
    # when the path actually has a directory component.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(path, 'w', encoding='utf-8') as f:
        f.write(content)

def load_domain_info(domain_name):
    """Read the three text inputs the pipeline needs for a domain.

    The files are looked up inside the directory named ``domain_name``:
    ``<domain_name>_domain.txt``, ``initial_state.txt`` and
    ``one_shot/ini_diagram_encoding/best_candidate.txt``.

    Returns:
        A ``(problem_description, initial_state, best_diagram_encoding)``
        tuple holding the contents of those files, in that order.
    """
    domain_dir = os.path.join(domain_name)
    description_path = os.path.join(domain_dir, f"{domain_name}_domain.txt")
    state_path = os.path.join(domain_dir, "initial_state.txt")
    encoding_path = os.path.join(
        domain_dir, "one_shot", "ini_diagram_encoding", "best_candidate.txt"
    )
    # Tuple elements are evaluated left to right, so the files are read in
    # the same order as they are returned.
    return (
        read_file(description_path),
        read_file(state_path),
        read_file(encoding_path),
    )

def run_matplotlib_code(code_str, output_png_path):
    """Execute generated plotting code in a subprocess and check its PNG output.

    The code is written to ``matplot_code.py`` next to the expected PNG and
    run with the current interpreter, limited to
    ``PARAMS["execution_timeout"]`` seconds.

    Args:
        code_str: Python source to execute. It is expected to save a figure
            to ``output_png_path`` itself.
        output_png_path: Path where the executed code should leave a PNG.

    Returns:
        ``None`` on success, otherwise a short string describing the failure
        (non-zero exit, timeout, or missing/empty PNG).
    """
    out_dir = os.path.dirname(output_png_path)
    # os.makedirs("") raises, so skip it when the PNG path is a bare filename.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Remove any PNG left over from an earlier run so that a script which
    # exits cleanly without saving anything is not mistaken for a success.
    if os.path.isfile(output_png_path):
        os.remove(output_png_path)

    # Write the code to a temporary .py file
    code_file_path = os.path.join(out_dir, "matplot_code.py")
    with open(code_file_path, 'w', encoding='utf-8') as f:
        f.write(code_str)

    # Use the interpreter running this script so the child sees the same
    # installed packages; fall back to PATH lookup if it is unknown.
    cmd = [sys.executable or "python", code_file_path]
    timeout = PARAMS["execution_timeout"]
    try:
        # subprocess.run kills the child and reaps it when the timeout expires.
        completed = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        return f"Execution error: timed out after {timeout} seconds"

    if completed.returncode != 0:
        # errors='replace' keeps undecodable bytes from raising here.
        return f"Execution error: {completed.stderr.decode('utf-8', errors='replace')}"

    # Check that PNG was created and is non-empty
    if not os.path.isfile(output_png_path) or os.path.getsize(output_png_path) == 0:
        return "No PNG generated or empty PNG file."
    return None

def verify_generated_diagram(problem_description, initial_state, diagram_encoding, diagram_png_path, diagram_reasoning, model):
    """Ask the model whether a rendered diagram is correct, clear and plausible.

    Args:
        problem_description: Text describing the problem domain.
        initial_state: Text describing the initial state being visualized.
        diagram_encoding: Text encoding of the objects the diagram should show.
        diagram_png_path: Path of the rendered PNG to check.
        diagram_reasoning: Explanation of the visual encoding used by the diagram.
        model: Model identifier forwarded to ``inference.get_model_response``.

    Returns:
        ``(is_valid, error_description)``. ``is_valid`` is a real ``bool``:
        ``True`` only when the model's ``yes_no`` block says "yes".
        ``error_description`` is the model's ``error`` block, or a generic
        message when the diagram was rejected without one.
    """
    if not os.path.isfile(diagram_png_path):
        return False, "PNG file not found or missing."

    # The PNG is passed as an {"image_path": ...} part alongside the text parts.
    prompt_parts = [f"""
    Consider the following problem description:
    {problem_description}

    This is the initial state of the problem:
    {initial_state}

    This is the diagram encoding of the state, describing shape, position, size, and status of each object:
    {diagram_encoding}

    We have generated the following diagram:""",
    {"image_path": diagram_png_path},
    "The following is the reaosning behind the diargam. It includes an explanation for the meaning behind each shape, color, different sizes and positions:",
    diagram_reasoning,
    """Analyze this diagram for correctness and clarity:
    - Does the diagram accurately visualize all objects as described in the initial state and the diagram encoding 
    with correct shapes, relative positions, and relative sizes? 
    - Does the reasoning match what is shown in the diagram?
    - Does each object in the diagram have a clear text identifier/lable inside it?
    - Are all of the objects in the diagram encoding clearly visiable in the diagram? For this you have make sure no 2 objects overlap and also the legend should not significantly overlay any objects
    - Is the status and textual lable of each object included inside its shape? Are the labels and statuses clear, easily readable, and with high contrast? Is it obvious which status/label belongs to which object?
    - Is the visualization of different objects of the same type consistent in the diagram? (e.g. if there are 10 blcoks in the diagram, are all blcoks visualized using rectangles of the same size?)
    - Does the diagram appear physically plausible (e.g., no floating objects if not approperiate, no misaliged cells in a grid, etc)?
    - Is the status of the objects acurately visualized with respect to the legend and the reasoning?
    
    Finally, provide "yes" if the diagram is correct, accurate, clear, and physically plausible, "no" otherwise, in the following format:
    ```yes_no
    <yes or no>
    ```
    If your final answer is "no", provide a short phrase describing the problem with the diagram in the following format:
    ```error
    <error description>
    ```
    """]

    response = inference.get_model_response(prompt_parts, model)
    validity = inference.extract_content(response, "yes_no")
    error_description = inference.extract_content(response, "error")

    # The extracted verdict is text, and a non-empty "no" is truthy, so it
    # must be compared explicitly. Strip everything but letters so answers
    # like "Yes", "<yes>" or "yes." are all accepted.
    verdict = re.sub(r"[^a-z]", "", validity.lower()) if isinstance(validity, str) else ""
    is_valid = verdict == "yes"

    if not is_valid and not error_description:
        error_description = "Verifier rejected the diagram without giving an error description."

    return is_valid, error_description

def rank_diagram_images(problem_description, initial_state, diagram_png_paths, diagram_encoding):
    """Ask the model to order candidate diagram PNGs from best to worst.

    Args:
        problem_description: Text describing the problem domain.
        initial_state: Text describing the initial state being visualized.
        diagram_png_paths: List of PNG paths; shown to the model as
            "Diagram 1", "Diagram 2", ... in this order.
        diagram_encoding: Text encoding of the objects the diagrams should show.

    Returns:
        ``(ordered_paths, reasoning)``. ``reasoning`` is the full model
        response. ``ordered_paths`` is the model's ranking; diagrams the
        model did not mention are appended in their original order. If no
        ``ranking`` block is found the input order is reversed; if the block
        is unusable (no indices, or an index out of range) the input order
        is returned unchanged.
    """
    prompt_parts = [
    f"""
    We have the following problem:
    {problem_description}

    Initial state:
    {initial_state}
    
    This is the diagram encoding of the state, describing shape, position, size, and status of each object:
    {diagram_encoding}

    We generated some candidate diagrams in PNG format. Below, we show each of them one by one. 
    Please rank them from best to worst based on:
    - accuracy in visualizing the initial state (including object placement, correct deption of relative size and positions)
    - how intuitive is the diagram (Which visualization makes more sense semantically with respect to the domain?) 
    - a diagram is better if status of a object is accurately visualized (e.g. by using a color) compared to a diagram that only uses text
    - how clear and easily understandable the diagram is 
    - physical plausibility (does the layout of the scene make sense, wether there are any laws of physics violated in the diagram)
    - readability of text
    - minimal overlaps
    """
    ]

    for idx, path in enumerate(diagram_png_paths):
        prompt_parts.append(f"\n\n Diagram {idx+1}:")
        prompt_parts.append({"image_path": path})

    prompt_parts.append("""
    Provide your ranking in the format:

    ```ranking
    <list the diagrams in order from best to worst by their index, e.g. 2,1,3>
    ```

    Above the code block, iterate through each diagram and write statement about its weaknesses and strengths in the original order given. Then rank the diagrams. Think step by step about the ranking and explain.
    """)

    response = inference.get_model_response(prompt_parts, model)

    # The entire response is the "reasoning"
    reasoning = response
    ranking = inference.extract_content(response, "ranking")

    # If the ranking block is missing or empty, default to the reverse of the
    # original order (i.e. last is best).
    if not ranking:
        return diagram_png_paths[::-1], reasoning

    total = len(diagram_png_paths)
    ranked_ids = []
    # Pull out whole numbers so multi-digit indices such as "10" are read as
    # one index rather than digit by digit.
    for token in re.findall(r"\d+", str(ranking)):
        diagram_id = int(token)
        if not 1 <= diagram_id <= total:
            print("failure in extracting the ranking returned by the model")
            return diagram_png_paths, reasoning
        # Keep only the first (best) position of a repeated index.
        if diagram_id not in ranked_ids:
            ranked_ids.append(diagram_id)

    # A block without any index would otherwise yield an empty ranking and
    # make callers that take element [0] fail.
    if not ranked_ids:
        print("failure in extracting the ranking returned by the model")
        return diagram_png_paths, reasoning

    # Diagrams the model left out go last, in their original order.
    ranked_ids.extend(i for i in range(1, total + 1) if i not in ranked_ids)

    # Convert 1-based to 0-based
    new_order = [diagram_png_paths[i - 1] for i in ranked_ids]
    return new_order, reasoning

def generate_diagram_code(
    problem_description,
    initial_state,
    prev_section,
    diagram_encoding,
    output_path,
    prev_error,
    temp
):
    """Prompt the model for matplotlib code that draws the initial state.

    Args:
        problem_description: Text describing the problem domain.
        initial_state: Text describing the initial state to draw.
        prev_section: Text listing previously generated code snippets the new
            code must differ from; falsy to omit that part of the prompt.
        diagram_encoding: Text encoding of the objects to draw.
        output_path: Path the generated code is told to save its figure to.
        prev_error: Description of an earlier failure to avoid; falsy to omit.
        temp: Sampling temperature forwarded to ``inference.get_model_response``.

    Returns:
        ``(code_str, diagram_reasoning)``: the contents of the response's
        ``code`` and ``reasoning`` blocks as returned by
        ``inference.extract_content``.
    """
    # NOTE: "\\n" below keeps a literal backslash-n in the prompt; an
    # unescaped "\n" would put a real line break in the middle of the
    # sentence instead of naming the escape sequence.
    prompt_string = [
    f"""
    You are creating Python matplotlib code that draws a diagram for the following initial state of a problem:

    Problem Description:
    {problem_description}

    Initial State:
    {initial_state}

    The following is a collection of descriptions of the relative/absolute position, relative/absolute size, 
    status (includes color if relevant), and text identifier for each object in the initial state's diagram encoding:
    {diagram_encoding}

    We want the final code to:
    1) Draw geometric shapes (rectangles, circles, arrows, etc.) for each object as described by the initial state and the diagram encoding. (For this ensure no two objects overlap)
    2) Place text labels and status for each object inside their shape. Make sure any text is clear and readable. For statuses, pick very short but descriptiive phrases, rather than clear/not clear.
    3) Ensure the final line of code saves the figure to:
    {output_path}
    4) If possible visualize the status of the objects and create a legend (hypothetical example: a legend mapping red to not clear status and green to a clear color status) (you must include the textal identifier/label and textual status of each object (includes grid cells if applicable) inside each its shape in any case) (Make sure legend does not overlay any objects of the diagram)
    5) Make sure the status of the objects is accurately visualized with respect to the legend. And also make sure the legend is not overlaying any of the objects int he diagram.
    6) Make sure the contrast between any text in the diagram and it's background is as high as possible to increase readability. You can use \\n for text labels and statuses to avoid overlap and save space horizontally.
    7) Make sure each object has a clearly visibla etext lable and status.""",
    (f"""Below are previously generated code snippets, You have to make sure your code is meaningfully different from the generations below:
    {prev_section}
    """ if prev_section else ""),
    (f"Additionally, avoid repeating the error: {prev_error}" if prev_error else ""),
    """
    Start by explaining your reasoning behind how you plan to encode each object in the diagram. This includes defining the meaning of each shape, color, size, or any other visual element. If you use colors to represent the status of objects, provide a detailed explanation of what each color signifies and the conditions under which an object is assigned a specific color. Ensure that someone unfamiliar with the domain can fully understand the diagram based on your description.
    Important: the reasoning should not include any information about the specific insatnce above, and must be applicalble to visualizations of any given instance int eh domain above. Only explain the meaning behind the colors, shapes, relative locations and sizes.
    Provide your reasoning using the following format:
    ```reasoning
    <detailed description of what each object, color, legend, size or different location int he diagram means>
    ```
    Finally, provide your new, unique Python code in this format:
    ```code
    <python code>
    ```
    """
    ]
    response = inference.get_model_response(prompt_string, model, temp)

    diagram_reasoning = inference.extract_content(response, "reasoning", remove_new_lines=False)
    code_str = inference.extract_content(response, "code", remove_new_lines=False)

    return code_str, diagram_reasoning

def _verdict_is_yes(verdict):
    """Interpret a verifier verdict that may be a bool or a "yes"/"no" string."""
    if isinstance(verdict, bool):
        return verdict
    return isinstance(verdict, str) and verdict.strip().lower() == "yes"


def get_1shot_diagram_code_ini(domain_name):
    """Generate, execute, verify and rank diagram code for a domain's initial state.

    Runs up to ``PARAMS["outer_loops"]`` rounds of ``PARAMS["inner_loops"]``
    generation attempts. Each attempt is executed and then checked by the
    verifier; the best verified attempt of every round is kept, and rounds
    stop early once ``PARAMS["num_samples"]`` winners have been collected.
    The winners are then ranked against each other and the overall best
    code, reasoning and PNG are written to
    ``<domain_name>/one_shot/ini_diagram_code``. Per-attempt files go to the
    ``attempts`` subdirectory.

    Args:
        domain_name: Name of the domain directory to read inputs from and
            write results under.

    Returns:
        ``(True, None)`` when at least one verified diagram was produced,
        otherwise ``(False, error_message)``.
    """
    problem_description, initial_state, best_diagram_encoding = load_domain_info(domain_name)

    base_task_dir = os.path.join(domain_name, "one_shot", "ini_diagram_code")
    attempts_dir = os.path.join(base_task_dir, "attempts")
    best_dir = base_task_dir
    os.makedirs(attempts_dir, exist_ok=True)
    os.makedirs(best_dir, exist_ok=True)

    # Collect the single best from each outer loop here:
    valid_diagram_codes = []
    valid_diagram_pngs = []
    valid_diagram_reasonings = []

    attempt_count = 0

    # Outer loop
    for outer_idx in range(PARAMS["outer_loops"]):
        print(f"[INFO] Outer loop #{outer_idx+1} starting")

        # Keep memory of code and error messages for the inner loop attempts
        prev_solutions_text = []
        prev_error = None

        # For each outer iteration, store valid attempts locally first
        this_inner_codes = []
        this_inner_pngs = []
        this_inner_reasonings = []

        # Inner loop
        for inner_idx in range(PARAMS["inner_loops"]):
            # Construct the previous code snippets section (memory)
            prev_section = "\n".join([
                f"-- A previously generated code snippet:\n{code}\n"
                for code in prev_solutions_text[:PARAMS["mem_capacity"]]
            ])

            attempt_count += 1
            output_png_path = os.path.join(attempts_dir, f"attempt_{attempt_count}.png")

            # Generate code + reasoning
            new_code, diagram_reasoning = generate_diagram_code(
                problem_description,
                initial_state,
                prev_section,
                best_diagram_encoding,
                output_png_path,
                prev_error,
                PARAMS["temperature"]
            )

            print(f"[INFO] Generated code attempt #{attempt_count} (outer={outer_idx+1}, inner={inner_idx+1})")

            # A response without a code block cannot be written or executed;
            # count it as a failed attempt instead of crashing on write.
            if not new_code:
                prev_error = "The previous response did not contain a ```code block with Python code."
                print(f"[ERROR] Attempt {attempt_count} produced no code block.")
                continue

            # A missing reasoning block becomes an empty string so it can
            # still be placed in the verifier prompt and written to disk.
            diagram_reasoning = diagram_reasoning or ""

            code_file_path = os.path.join(attempts_dir, f"attempt_{attempt_count}.py")
            write_file(code_file_path, new_code)

            # Execute the code to produce a PNG
            error = run_matplotlib_code(new_code, output_png_path)
            if error:
                prev_error = (
                    f"Error from the following attempt:\n{new_code}\nThe error received: {error}"
                )
                with open(code_file_path, 'a', encoding='utf-8') as f:
                    f.write(f"\n# EXECUTION ERROR:\n# {error}\n")
                print(f"[ERROR] Attempt {attempt_count} failed execution: {error}")
                continue

            # Verify diagram correctness
            is_valid, err_msg = verify_generated_diagram(
                problem_description,
                initial_state,
                best_diagram_encoding,
                output_png_path,
                diagram_reasoning,
                model
            )
            # The verdict may come back as text; a plain truthiness test
            # would treat the string "no" as a pass.
            if not _verdict_is_yes(is_valid):
                prev_error = (
                    f"Verification failed for the following attempt:\n{new_code}\nThe error received: {err_msg}"
                )
                with open(code_file_path, 'a', encoding='utf-8') as f:
                    f.write(f"\n# VERIFICATION FAILED:\n# {err_msg}\n")
                print(f"[ERROR] Attempt {attempt_count} verification failed: {err_msg}")
                continue

            prev_error = None
            print(f"[INFO] Attempt {attempt_count} successfully verified!")
            # Store into local list for this inner iteration
            this_inner_codes.append(new_code)
            this_inner_pngs.append(output_png_path)
            this_inner_reasonings.append(diagram_reasoning)
            # Also store in memory to avoid repeating similar code
            prev_solutions_text.append(new_code)

        # End of inner loop: nothing verified in this round, try the next one
        if not this_inner_codes:
            continue

        if len(this_inner_codes) > 1:
            # Several valid attempts: let the model pick the best
            ranked_pngs, _local_reasoning = rank_diagram_images(
                problem_description, initial_state, this_inner_pngs, best_diagram_encoding
            )
            best_inner_png = ranked_pngs[0]
            best_inner_idx = this_inner_pngs.index(best_inner_png)
        else:
            # Only one valid
            best_inner_idx = 0
            best_inner_png = this_inner_pngs[0]

        best_inner_code = this_inner_codes[best_inner_idx]
        best_inner_reasoning = this_inner_reasonings[best_inner_idx]

        # Save the best from this inner loop in the attempts folder
        best_code_file = os.path.join(attempts_dir, f"best_attempt_inner_loop{outer_idx+1}.py")
        best_png_file = os.path.join(attempts_dir, f"best_attempt_inner_loop{outer_idx+1}.png")
        best_reasoning_file = os.path.join(attempts_dir, f"best_attempt_inner_loop{outer_idx+1}_reasoning.txt")

        write_file(best_code_file, best_inner_code)
        write_file(best_reasoning_file, best_inner_reasoning)
        with open(best_inner_png, 'rb') as src, open(best_png_file, 'wb') as dst:
            dst.write(src.read())

        # Add the best from this outer iteration's inner loop results to the global valid list
        valid_diagram_codes.append(best_inner_code)
        valid_diagram_pngs.append(best_png_file)
        valid_diagram_reasonings.append(best_inner_reasoning)

        # If we have enough total best samples, stop outer loops
        if len(valid_diagram_codes) >= PARAMS["num_samples"]:
            print(f"[INFO] Reached desired number of samples ({PARAMS['num_samples']}). Stopping.")
            break

    # Finished all outer loops or broke early
    if not valid_diagram_codes:
        error = "[ERROR] No valid diagram codes found; pipeline ended with no successes."
        print(error)
        return False, error

    if len(valid_diagram_codes) > 1:
        # Shuffle before the final ranking so the order in which candidates
        # were produced is not the order in which they are presented.
        combined = list(zip(valid_diagram_codes, valid_diagram_pngs, valid_diagram_reasonings))
        random.shuffle(combined)
        valid_diagram_codes, valid_diagram_pngs, valid_diagram_reasonings = [list(t) for t in zip(*combined)]

        print("[INFO] Final ranking of collected best diagrams...")
        ranked_pngs, _reasoning = rank_diagram_images(
            problem_description, initial_state, valid_diagram_pngs, best_diagram_encoding
        )
        best_png = ranked_pngs[0]
        best_idx = valid_diagram_pngs.index(best_png)
    else:
        # A single candidate needs no ranking call.
        best_idx = 0
        best_png = valid_diagram_pngs[0]

    best_code = valid_diagram_codes[best_idx]
    best_reasoning = valid_diagram_reasonings[best_idx]

    best_code_file = os.path.join(best_dir, "best_candidate_code.py")
    write_file(best_code_file, best_code)

    best_reasoning_file = os.path.join(best_dir, "best_candidate_reasoning.txt")
    write_file(best_reasoning_file, best_reasoning)

    best_diagram_file = os.path.join(best_dir, "best_diagram.png")
    with open(best_png, 'rb') as src, open(best_diagram_file, 'wb') as dst:
        dst.write(src.read())

    print("[SUCCESS] At least one valid diagram code solution found.")
    print(f"[INFO] Best code saved at: {best_code_file}")
    print(f"[INFO] Best diagram saved at: {best_diagram_file}")

    return True, None


def main():
    """Command-line entry point: run the diagram-code pipeline for one domain.

    Expects the domain name as the first command-line argument. Exits with
    status 1 when the argument is missing or when the pipeline reports that
    no valid diagram was produced.
    """
    if len(sys.argv) < 2:
        print("Usage: python one_shot_diagram_code.py <domain_name>")
        sys.exit(1)

    domain_name = sys.argv[1]
    print("Started on getting verified and ranked diagram code for the initial state")
    success, _error = get_1shot_diagram_code_ini(domain_name)
    if not success:
        # The pipeline has already printed the failure message; surface it
        # to the shell through a non-zero exit status as well.
        sys.exit(1)

if __name__ == "__main__":
    main()